import komand
from .schema import GetVulnerabilityAffectedAssetsInput, GetVulnerabilityAffectedAssetsOutput
# Custom imports below
import requests
from komand_rapid7_insightvm.util import endpoints
import json
from komand.exceptions import PluginException


class GetVulnerabilityAffectedAssets(komand.Action):
    """Action that retrieves the assets affected by a vulnerability from the InsightVM console."""

    # Human-readable labels for HTTP status codes, used only when logging a failed request.
    # Key 0 is the fallback label for any status code not listed here.
    _ERRORS = {
        401: "Unauthorized",
        404: "Not Found",
        500: "Internal Server Error",
        503: "Service Unavailable",
        0: "Unknown Status Code"
    }

    def __init__(self):
        # Zero-argument super() avoids the infinite recursion that
        # super(self.__class__, self) causes if this class is ever subclassed.
        super().__init__(
                name='get_vulnerability_affected_assets',
                description='Get the assets affected by the vulnerability',
                input=GetVulnerabilityAffectedAssetsInput(),
                output=GetVulnerabilityAffectedAssetsOutput())

    def run(self, params=None):
        """
        Look up the assets affected by a vulnerability.
        :param params: Action input; the "vulnerability_id" key is read from it
        :return: Dict with "links" and "resources" taken from the API response;
                 either one defaults to an empty list when missing from the response
        :raises PluginException: propagated from get_assets on request/response failures
        """
        if params is None:
            params = {}
        vulnerability_id = params.get("vulnerability_id")

        endpoint = endpoints.Vulnerability.vulnerability_affected_assets(self.connection.console_url,
                                                                         vulnerability_id)

        results = self.get_assets(endpoint)

        return {"links": self._extract_list(results, "links"),
                "resources": self._extract_list(results, "resources")}

    def _extract_list(self, results: dict, key: str):
        """
        Return results[key], or an empty list (logging a warning message) when the key is absent.
        :param results: Decoded JSON response from get_assets
        :param key: Top-level response key to read
        :return: The value stored under key, or an empty list
        """
        try:
            return results[key]
        except KeyError:
            self.logger.info("Warning: No {key} returned in response. Using empty list.".format(key=key))
            return list()

    def get_assets(self, endpoint: str) -> dict:
        """
        Retrieves assets for a vulnerability
        :param endpoint: Endpoint to reach
        :return: JSON result in {links: []dict, resources: []int} format
        :raises PluginException: UNKNOWN if the request fails or the console returns a non-success
                                 status; INVALID_JSON if the response body cannot be decoded as JSON
        """
        try:
            # NOTE(review): TLS certificate verification is disabled here; presumably to support
            # consoles with self-signed certificates — confirm this is intentional.
            response = self.connection.session.get(url=endpoint,
                                                   verify=False)
        except requests.RequestException as e:
            raise PluginException(preset=PluginException.Preset.UNKNOWN, data=str(e)) from e

        if response.status_code in (200, 201):  # 200 is documented, 201 is undocumented
            try:
                return response.json()
            except ValueError as e:
                # ValueError is the common base of json's and requests' JSON decode errors.
                raise PluginException(preset=PluginException.Preset.INVALID_JSON, data=response.text) from e

        try:
            reason = response.json()["message"]
        except (KeyError, TypeError):
            # KeyError: JSON object without "message"; TypeError: JSON body that is not an object.
            reason = "Unknown error occurred. Please contact support or try again later."
        except ValueError as e:
            # The body is not JSON at all, so `reason` was never assigned; report the raw text instead.
            raise PluginException(preset=PluginException.Preset.INVALID_JSON, data=response.text) from e

        status_code_message = self._ERRORS.get(response.status_code, self._ERRORS[0])
        self.logger.error("{status} ({code}): {reason}".format(status=status_code_message,
                                                               code=response.status_code,
                                                               reason=reason))
        # Pass the server-provided reason along so callers see why the request failed.
        raise PluginException(preset=PluginException.Preset.UNKNOWN, data=reason)
